#  + - - - - - -
#  |  run with: snakemake final_rule --cores 4 --config ...
#  + - - - 

import os
import yaml

# Model names accepted by the {model} wildcard and expanded when
# collecting the GAN sources for compilation.
MODELS = [
    "Rich",
    "Muon",
    "GlobalPIDmu",
    "GlobalPIDh",
]

# Particle names accepted by the {part} wildcard.
PARTS = [
    "muon",
    "pion",
    "kaon",
    "proton",
]

# Flags handed to gcc when building the shared library. Kept as a list
# because "-DDEBUG" is appended to it when the debug option is set.
GCC_FLAGS = [
    "-O3",
    "-lm",
    "--shared",
    "-fPIC",
]

# +-------------------+
# |   Initial setup   |
# +-------------------+

# Workflow options (debug, data_sample, label, max_q_err, err_patience)
# are read from this file into the global `config` mapping.
configfile: "config/snakemake_config.yml"

# Directory layout shared with the rest of the project.
# NOTE(review): this path is resolved against the current working
# directory and starts with "../config", while the configfile above uses
# "config/" -- confirm that both locations are intended.
# safe_load is enough for a plain mapping of directory names and, unlike
# full_load, never instantiates Python-specific YAML tags.
with open("../config/directories.yml", encoding="utf-8") as file:
    config_dir = yaml.safe_load(file)

models_dir = config_dir["models_dir"]
exports_dir = config_dir["exports_dir"]

# A missing "debug" key is treated as False instead of raising KeyError.
if config.get("debug", False):
    GCC_FLAGS.append("-DDEBUG")

# Each constraint is a regex alternation of the allowed names, so the
# {model} and {part} wildcards can only match entries of MODELS / PARTS.
wildcard_constraints:
    model = "|".join(MODELS),
    part = "|".join(PARTS),

# +------------------------------+
# |   GAN models transpilation   |
# +------------------------------+

# Run deploy_gan_models.py on one trained GAN model directory to obtain
# the corresponding C source file.
# NOTE(review): the output path is not passed to the script, so the
# script is presumably writing /tmp/<...>-gan.C on its own -- confirm the
# file name it produces matches the pattern declared below.
rule deploy_gan_models:
    input:
        os.path.join(models_dir, "{model}_{part}_models/{label}_{model}GAN-{part}_{sample}_model"),

    # Declared temp(): Snakemake deletes the file once no remaining job
    # needs it as input.
    output:
        temp("/tmp/{model}-{part}_{sample}_{label}-gan.C"),

    log:
        "logs/deploy-gan-models_{model}-{part}_{sample}_{label}.log"

    # "2>&1" sends stderr (e.g. Python tracebacks) to the log as well;
    # redirecting stdout alone left failures out of the log file.
    shell:
        "python deploy_gan_models.py -m {wildcards.model} -p {wildcards.part} -D {wildcards.sample} -M {input} > {log} 2>&1"

# +---------------------------------+
# |   isMuon models transpilation   |
# +---------------------------------+

# Run deploy_ann_models.py on one trained isMuon ANN model directory to
# obtain the corresponding C source file.
# NOTE(review): the output path is not passed to the script, so the
# script is presumably writing /tmp/isMuon-<...>-ann.C on its own --
# confirm the file name it produces matches the pattern declared below.
rule deploy_ann_models:
    input:
        os.path.join(models_dir, "isMuon_{part}_models/{label}_isMuonANN-{part}_{sample}_model"),

    # Declared temp(): Snakemake deletes the file once no remaining job
    # needs it as input.
    output:
        temp("/tmp/isMuon-{part}_{sample}_{label}-ann.C"),

    log:
        "logs/deploy-ann-models_isMuon-{part}_{sample}_{label}.log"

    # "2>&1" sends stderr (e.g. Python tracebacks) to the log as well;
    # redirecting stdout alone left failures out of the log file.
    shell:
        "python deploy_ann_models.py -m isMuon -p {wildcards.part} -D {wildcards.sample} -M {input} > {log} 2>&1"

# +----------------------------+
# |    Pipeline compilation    |
# +----------------------------+

# Compile every generated GAN / isMuon C source together with
# cpipelines/full_pipelines.C into a single shared library with gcc.
# NOTE(review): sample and label are filled from `config` below rather
# than from this rule's own {sample}/{label} wildcards. Because every
# placeholder is supplied, allow_missing=True has no effect -- confirm
# whether the inputs were meant to follow the output's wildcards.
rule compile_full_pipes:
    input:
        # One GAN source per (model, part) combination.
        gan_model = expand(
            "/tmp/{model}-{part}_{sample}_{label}-gan.C",
            model=MODELS,
            part=PARTS,
            sample=config["data_sample"],
            label=config["label"],
            allow_missing=True,
        ),
        # One isMuon source per particle.
        ismuon_model = expand(
            "/tmp/isMuon-{part}_{sample}_{label}-ann.C",
            part=PARTS,
            sample=config["data_sample"],
            label=config["label"],
            allow_missing=True,
        ),
        full_pipeline = "cpipelines/full_pipelines.C",

    output:
        os.path.join(exports_dir, "CompiledModel_{sample}_{label}.so"),

    params:
        flags = GCC_FLAGS,

    log:
        "logs/compile-full-pipes_{sample}_{label}.log"

    # gcc reports warnings and errors on stderr, so "2>&1" is required
    # for the log to contain them; stdout alone is normally empty.
    # NOTE(review): "-lm" in the flags precedes the source files on the
    # command line; link order can matter for libraries -- confirm.
    shell:
        "gcc {params.flags} -o {output} {input} > {log} 2>&1"

# +--------------------------+
# |    Test transpilation    |
# +--------------------------+

# Run tests/test.py against the compiled shared library for one particle.
# The rule declares no `output:`; its log file is the only product and is
# what final_rule requests to trigger it.
# NOTE(review): since the log is the target, a log left behind by a
# failed run may need to be removed before the test is re-run -- confirm.
rule test_transpilation:
    input:
        os.path.join(exports_dir, "CompiledModel_{sample}_{label}.so"),

    # Thresholds forwarded to the test script via -E and -P.
    params:
        max_q_err = config["max_q_err"],
        err_patience = config["err_patience"],

    log:
        "logs/test-transpilation_{part}_{sample}_{label}.log"

    # "2>&1" sends stderr (tracebacks, failure messages) to the log as
    # well; redirecting stdout alone left them out of the log file.
    shell:
        "python tests/test.py -o {input} -p {wildcards.part} -D {wildcards.sample} -E {params.max_q_err} -P {params.err_patience} --figure > {log} 2>&1"


# +------------------+
# |    Final rule    |
# +------------------+

# Aggregate target: requests the compiled libraries and the per-particle
# test logs for the configured sample(s) and label(s), then drops a
# marker file once all of them exist.
rule final_rule:
    input:
        libraries = expand(
            os.path.join(exports_dir, "CompiledModel_{sample}_{label}.so"),
            sample=config["data_sample"], label=config["label"],
        ),
        test_logs = expand(
            "logs/test-transpilation_{part}_{sample}_{label}.log",
            part=PARTS, sample=config["data_sample"], label=config["label"],
        ),

    output:
        ".all_rules_completed"

    shell:
        "touch {output}"
